{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from scipy.ndimage import convolve\n",
    "from sklearn import linear_model, datasets, metrics\n",
    "from sklearn.neural_network import BernoulliRBM\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.grid_search import GridSearchCV "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the data set: augment it with one-pixel-shifted copies of every image, generated by convolution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def nudge_dataset(X, Y):\n",
    "    \"\"\"Augment a set of flattened square images with one-pixel shifts.\n",
    "\n",
    "    Every kernel in ``direction_vectors`` holds a single 1 next to its\n",
    "    centre, so convolving an image with it (zero padding via\n",
    "    ``mode='constant'``) moves the image by one pixel. The shifted copies\n",
    "    are stacked below the original rows and the labels are repeated to match.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray of shape (n_samples, side * side)\n",
    "        One flattened square image per row.\n",
    "    Y : array-like with n_samples entries along axis 0\n",
    "        Labels belonging to the rows of X.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X, Y : ndarray\n",
    "        X with the original rows first followed by one block of shifted rows\n",
    "        per kernel, and Y repeated once per block.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the number of columns of X is not a perfect square.\n",
    "    \"\"\"\n",
    "    direction_vectors = [\n",
    "        [[0, 1, 0],\n",
    "         [0, 0, 0],\n",
    "         [0, 0, 0]],\n",
    "\n",
    "        [[0, 0, 0],\n",
    "         [1, 0, 0],\n",
    "         [0, 0, 0]],\n",
    "\n",
    "        [[0, 0, 0],\n",
    "         [0, 0, 1],\n",
    "         [0, 0, 0]],\n",
    "\n",
    "        [[0, 0, 0],\n",
    "         [0, 0, 0],\n",
    "         [0, 1, 0]]]\n",
    "\n",
    "    n_pixels = X.shape[1]\n",
    "    # np.sqrt keeps the function independent of names injected by %pylab.\n",
    "    side = int(round(np.sqrt(n_pixels)))\n",
    "    if side * side != n_pixels:\n",
    "        raise ValueError(\n",
    "            'X must hold flattened square images, got %d columns' % n_pixels)\n",
    "    matrix_size = (side, side)\n",
    "\n",
    "    def shift(x, w):\n",
    "        # Un-flatten one row, shift it with kernel w, flatten it again.\n",
    "        return convolve(x.reshape(matrix_size), mode='constant',\n",
    "                        weights=w).ravel()\n",
    "\n",
    "    X = np.concatenate([X] +\n",
    "                       [np.apply_along_axis(shift, 1, X, vector)\n",
    "                        for vector in direction_vectors])\n",
    "    # One copy of the labels for the originals plus one per shifted block,\n",
    "    # derived from the kernel list instead of a hard-coded 5.\n",
    "    Y = np.concatenate([Y] * (len(direction_vectors) + 1), axis=0)\n",
    "    return X, Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the data set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from util.dataset import read_data_set\n",
    "from util.image import half"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the training split, then append the shifted copies from nudge_dataset.\n",
    "X_train, y_train = map(asarray, read_data_set('train', half))\n",
    "X_train, y_train = nudge_dataset(X_train, y_train)\n",
    "# 0-1 scaling; the 0.0001 keeps the division defined when a column's max is 0.\n",
    "X_train = (X_train - X_train.min(axis=0)) / (X_train.max(axis=0) + 0.0001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Load the test split and append the shifted copies from nudge_dataset.\n",
    "X_test, y_test = map(asarray, read_data_set('t10k', half))\n",
    "X_test, y_test = nudge_dataset(X_test, y_test)\n",
    "# 0-1 scaling; the 0.0001 keeps the division defined when a column's max is 0.\n",
    "# NOTE(review): this uses the test set's own min/max, not the training\n",
    "# statistics -- confirm that is intended.\n",
    "X_test = (X_test - X_test.min(axis=0)) / (X_test.max(axis=0) + 0.0001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Construct the classifier: a Bernoulli RBM feature extractor followed by logistic regression."
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Stand-alone logistic regression with a fixed regularisation setting C.\n",
    "logistic = linear_model.LogisticRegression(C=6000.0)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Stand-alone RBM with hand-picked hyper-parameters.\n",
    "rbm = BernoulliRBM(n_components=100,\n",
    "                   learning_rate=0.06,\n",
    "                   n_iter=20,\n",
    "                   random_state=0,\n",
    "                   verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Two-step model: RBM feature extraction feeding a logistic regression.\n",
    "# The step names ('rbm', 'logistic') are the prefixes used in the\n",
    "# parameter grid of the grid search.\n",
    "classifier = Pipeline([\n",
    "    ('rbm', BernoulliRBM(verbose=True, random_state=0)),\n",
    "    ('logistic', linear_model.LogisticRegression())\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
     ]
    }
   ],
   "source": [
    "# Search over 2 x 2 x 2 = 8 hyper-parameter combinations; keys follow the\n",
    "# '<step name>__<parameter>' convention of the pipeline.\n",
    "grid_search = GridSearchCV(\n",
    "    classifier,\n",
    "    {\n",
    "        'rbm__learning_rate': (0.05, 0.1),\n",
    "        'rbm__n_iter': (10, 30),\n",
    "        'logistic__C': (5000, 7000)\n",
    "    },\n",
    "    n_jobs=2,\n",
    "    verbose=1\n",
    ")\n",
    "grid_search.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the best parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Full parameter set of the best pipeline found by the grid search,\n",
    "# printed one 'name: value' pair per line.\n",
    "best_parameters = grid_search.best_estimator_.get_params()\n",
    "for param_name, param_value in best_parameters.items():\n",
    "    print('%s: %s' % (param_name, param_value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model with the best parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Apply only the hyper-parameters tuned by the grid search. set_params\n",
    "# takes keyword arguments, so the dict is unpacked with ** (a single *\n",
    "# would pass its keys as positional arguments and fail).\n",
    "classifier.set_params(**grid_search.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fit the pipeline on the full (augmented, scaled) training set.\n",
    "classifier.fit(X_train, y=y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Text report comparing the pipeline's predictions with the test labels.\n",
    "print(\"Logistic regression using RBM features:\\n%s\\n\"\n",
    "      % metrics.classification_report(y_test, classifier.predict(X_test)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
